import sys
import copy
import random
import pymongo
from scipy.stats import pearsonr

# Metadata hash types to evaluate as clusterings (presumed use; "pehash" is
# the metadata field read by the loader below -- verify if more are added).
META_TYPES = [
    "pehash",
]

# Label types whose precision and recall are reported for each clustering.
LABEL_TYPES = [
    "avclass_default_prep_0",
]

# Label types for which the shuffle/Pearson-correlation check is also run.
CORRELATION_LABEL_TYPES = [
    "avclass_default_prep_0",
    "avclass_alias_prep_0",
    "avclass_genav_remove_0",
]


def get_precision(C, D, md5_D):
    """Return the precision of clustering ``C`` measured against reference ``D``.

    For every cluster in ``C``, take its largest overlap with any ``D`` cluster
    that contains one of its members. Sum these overlaps and divide by the
    total number of samples in ``C``.

    Args:
        C: dict mapping cluster key -> set of md5s (clustering being evaluated).
        D: dict mapping cluster key -> collection of md5s (reference clustering).
        md5_D: dict mapping md5 -> key in ``D``; must cover every md5 in ``C``.

    Returns:
        Precision as a float in [0, 1]; 0.0 when ``C`` holds no samples.
    """
    hits = 0
    total = 0
    for md5s in C.values():
        if not md5s:
            # Empty clusters contribute nothing (and max() of nothing fails).
            continue
        # Members usually share a few D clusters; intersect with each distinct
        # D cluster once instead of once per member.
        candidate_keys = {md5_D[md5] for md5 in md5s}
        hits += max(len(md5s.intersection(D[key])) for key in candidate_keys)
        total += len(md5s)
    return hits / total if total else 0.0


def get_recall(C, D, md5_C):
    """Return the recall of clustering ``C`` measured against reference ``D``.

    For every cluster in ``D``, take its largest overlap with any ``C`` cluster
    that contains one of its members. Sum these overlaps and divide by the
    total number of samples in ``D``.

    Args:
        C: dict mapping cluster key -> collection of md5s (evaluated clustering).
        D: dict mapping cluster key -> set of md5s (reference clustering).
        md5_C: dict mapping md5 -> key in ``C``; must cover every md5 in ``D``.

    Returns:
        Recall as a float in [0, 1]; 0.0 when ``D`` holds no samples.
    """
    hits = 0
    total = 0
    for md5s in D.values():
        if not md5s:
            # Empty clusters contribute nothing (and max() of nothing fails).
            continue
        # Intersect with each distinct C cluster once instead of once per member.
        candidate_keys = {md5_C[md5] for md5 in md5s}
        hits += max(len(md5s.intersection(C[key])) for key in candidate_keys)
        total += len(md5s)
    return hits / total if total else 0.0


def print_cluster_stats(clusters):
    """Print summary statistics for a clustering.

    Prints the total number of members, the number of clusters, and the number
    of singleton clusters. For a non-empty clustering it also prints the
    average and largest cluster size and the key of the largest cluster (the
    first one in iteration order on ties).

    Args:
        clusters: dict mapping cluster key -> collection of members.
    """
    sizes = [len(vals) for vals in clusters.values()]
    count = sum(sizes)
    num_clusters = len(sizes)
    num_singletons = sizes.count(1)

    print("Number of values: {}".format(count))
    print("Number of clusters: {}".format(num_clusters))
    print("Number of singletons: {}".format(num_singletons))
    if not num_clusters:
        # Average and largest cluster are undefined for an empty clustering.
        return

    largest_key, largest_vals = max(clusters.items(), key=lambda item: len(item[1]))
    print("Average cluster size: {:.3f}".format(float(count) / num_clusters))
    print("Largest cluster size: {}".format(len(largest_vals)))
    print("Most common value: {}".format(largest_key))


def print_pearsonr(M, C, D, md5_C, md5_D, rng=None):
    """Print how precision and recall change as ``C`` is progressively shuffled.

    A copy of clustering ``C`` is randomized with a Fisher-Yates pass over the
    samples in ``M``. Each step swaps the ``C`` cluster membership of two
    samples. About every ``len(M) // 20`` steps (and at the last step), the
    precision and recall of the partially shuffled ``C`` against ``D`` are
    recorded. The Pearson correlation of each series with the checkpoint index
    is then printed. The inputs are not modified.

    Args:
        M: collection of all sample md5s.
        C: dict mapping cluster key -> set of md5s (the clustering to shuffle).
        D: dict mapping cluster key -> set of md5s (the reference clustering).
        md5_C: dict mapping md5 -> key in ``C``.
        md5_D: dict mapping md5 -> key in ``D``.
        rng: optional object with a ``randint(a, b)`` method (for example a
            ``random.Random`` instance, for reproducible runs). Defaults to
            the ``random`` module.
    """
    if rng is None:
        rng = random
    M_list = list(M)
    # Copying the containers is enough: only the sets and the md5 -> key
    # mapping are mutated below, never the (hashable) elements themselves.
    C_shuffled = {key: set(md5s) for key, md5s in C.items()}
    md5_C_shuffled = dict(md5_C)
    # With fewer than 20 samples len(M) // 20 is 0 and ``i % step`` would fail.
    step = max(1, len(M_list) // 20)
    precisions = []
    recalls = []

    # Fisher-Yates: i runs from n-1 down to 1; i == 0 could only swap with
    # itself, so stopping at 1 is correct.
    for i in range(len(M_list) - 1, 0, -1):
        j = rng.randint(0, i)
        md5_1 = M_list[i]
        md5_2 = M_list[j]
        meta_hash_1 = md5_C_shuffled[md5_1]
        meta_hash_2 = md5_C_shuffled[md5_2]

        # Swapping a sample with itself or within one cluster is a no-op.
        # Skip only the swap, not the checkpoint below.
        if md5_1 != md5_2 and meta_hash_1 != meta_hash_2:
            # Move md5_1 into md5_2's cluster and vice versa.
            md5_C_shuffled[md5_1] = meta_hash_2
            md5_C_shuffled[md5_2] = meta_hash_1
            C_shuffled[meta_hash_1].remove(md5_1)
            C_shuffled[meta_hash_1].add(md5_2)
            C_shuffled[meta_hash_2].remove(md5_2)
            C_shuffled[meta_hash_2].add(md5_1)

        if i % step == 0 or i == 1:
            print(i)
            precisions.append(get_precision(C_shuffled, D, md5_D))
            recalls.append(get_recall(C_shuffled, D, md5_C_shuffled))

    if len(precisions) < 2:
        # pearsonr needs at least two points to compute a correlation.
        print("Not enough samples to compute correlation")
        return

    p_correlation, p_pvalue = pearsonr(range(len(precisions)), precisions)
    r_correlation, r_pvalue = pearsonr(range(len(recalls)), recalls)
    print("Precision correlation:", p_correlation, p_pvalue)
    print("Recall correlation:", r_correlation, r_pvalue)


def _group_by_value(md5_to_key):
    """Invert an ``md5 -> cluster key`` mapping into ``cluster key -> set of md5s``.

    Args:
        md5_to_key: dict mapping each sample md5 to the key of its cluster.

    Returns:
        dict mapping each distinct key to the set of md5s assigned to it.
    """
    clusters = {}
    for md5, key in md5_to_key.items():
        clusters.setdefault(key, set()).add(md5)
    return clusters


if __name__ == "__main__":
    M = set()
    md5_pehash = {}
    md5_C = {}

    # One client serves every collection used below; it is closed at the end.
    client = pymongo.MongoClient("127.0.0.1", 27017)
    db = client["agtr_db"]

    # Load sample metadata. Samples without a pehash fall back to their own
    # md5 as cluster key, i.e. they become singletons.
    print("Loading from pymongo")
    sys.stdout.flush()
    for document in db["metadata"].find({}):
        md5 = document["md5"]
        M.add(md5)
        pehash = document.get("pehash")
        md5_pehash[md5] = pehash if pehash is not None else md5

    # Load AVClass labels for the known samples; missing labels fall back to
    # the sample's own md5 (a singleton label).
    print("Loading labels from pymongo")
    sys.stdout.flush()
    for document in db["avclass_default_prep"].find({}):
        md5 = document["md5"]
        if md5 not in M:
            continue
        label = document.get("avclass_default_prep_0")
        md5_C[md5] = label if label is not None else md5

    # Samples absent from the label collection also get a singleton label.
    for md5 in M:
        md5_C.setdefault(md5, md5)

    # Construct AGTR (pehash clustering).
    R = _group_by_value(md5_pehash)
    print_cluster_stats(R)

    # Find pehash clusters whose labeled members disagree on the label.
    # Unlabeled samples (label == md5) do not count towards disagreement.
    # The scan stops after 100 broken clusters, so the totals only cover the
    # clusters visited up to that point.
    total_samples = 0
    labeled_samples = 0
    total_clusters = 0
    broken_clusters = 0
    broken_cluster_samples = 0
    clusters = list(R.values())
    random.shuffle(clusters)
    for md5s in clusters:
        total_clusters += 1
        total_samples += len(md5s)
        label_set = set()
        for md5 in md5s:
            label = md5_C[md5]
            if label == md5:
                continue
            labeled_samples += 1
            label_set.add(label)
        if len(label_set) > 1:
            broken_clusters += 1
            broken_cluster_samples += len(md5s)
            for md5 in md5s:
                res = db["rj_labels"].find_one({"md5": md5})
                # find_one returns None when the sample is not in rj_labels.
                print(md5, res.get("token_counts") if res is not None else None)
            print()
        if broken_clusters >= 100:
            break

    print("Total samples:", total_samples)
    print("Labeled samples:", labeled_samples)
    print("Total clusters:", total_clusters)
    print("Broken clusters:", broken_clusters)
    print("Samples from broken clusters:", broken_cluster_samples)

    # md5 -> cluster-key mappings per metadata type and per label type, plus
    # the label clusterings, as the evaluation loop below expects them.
    meta = {"pehash": md5_pehash}
    labels = {"avclass_default_prep_0": md5_C}
    C = {
        label_type: _group_by_value(md5_to_label)
        for label_type, md5_to_label in labels.items()
    }

    # Create metadata hash AGTR and evaluate it against each label type.
    for meta_type in META_TYPES:
        threshold = None
        if meta_type.startswith("imphash_"):
            # "imphash_<N>": imphash is only used for samples with at least N
            # imports. NOTE: requires meta["imphash"] and meta["num_imports"],
            # which this script does not load.
            threshold = int(meta_type[len("imphash_"):])
            meta_type = "imphash"

        # Build a fresh mapping so the loaded meta[...] dict is never mutated.
        md5_R = {}
        for md5, meta_hash in meta[meta_type].items():
            if threshold is not None:
                num_imports = meta["num_imports"][md5]
                if not isinstance(num_imports, int) or num_imports < threshold:
                    meta_hash = md5
            md5_R[md5] = meta_hash
        R = _group_by_value(md5_R)

        print(meta_type)
        print_cluster_stats(R)
        for label_type in LABEL_TYPES:
            if label_type not in labels:
                print("Labels not loaded for", label_type)
                continue
            precision = get_precision(C[label_type], R, md5_R)
            recall = get_recall(C[label_type], R, labels[label_type])
            fmt = "{}\t precision: {:.5f}\trecall: {:.5f}"
            print(fmt.format(label_type, precision, recall))
            if label_type in CORRELATION_LABEL_TYPES:
                print_pearsonr(M, C[label_type], R, labels[label_type], md5_R)
        print()

    client.close()
